#include "codegen.h"
#include "llvm/IR/Verifier.h"

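// Constant folding already happens at the IR level: IRBuilder folds constant
// operands, so for constant expressions the generated module looks like this: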
/*
declare i32 @printf(ptr, ...)

define i32 @main() {
entry:
  %0 = call i32 (ptr, ...) @printf(ptr @0, i32 17)
  %1 = call i32 (ptr, ...) @printf(ptr @1, i32 0)
  %2 = call i32 (ptr, ...) @printf(ptr @2, i32 2)
  ret i32 0
}
*/

/// Emits the whole program as an `i32 main()` that evaluates each top-level
/// expression in `p->ExprVec` (in order) and prints its value with `printf`,
/// then verifies `main` and prints the module IR to stdout.
///
/// @param p  program whose expressions are lowered.
/// @return   the `ret i32 0` instruction that terminates `main`.
llvm::Value *CodeGen::VisitProgram(Program *p)
{
    // Declare the external variadic `printf`: i32 printf(ptr, ...)
    auto printFunctionType = llvm::FunctionType::get(irBuilder.getInt32Ty(), {irBuilder.getInt8PtrTy()}, true);
    auto printFunction = llvm::Function::Create(printFunctionType, llvm::GlobalValue::ExternalLinkage, "printf", module.get());
    // Define `i32 main()` with a single entry block and emit code into it.
    auto mainFunctionType = llvm::FunctionType::get(irBuilder.getInt32Ty(), false);
    auto mainFunction = llvm::Function::Create(mainFunctionType, llvm::GlobalValue::ExternalLinkage, "main", module.get());
    llvm::BasicBlock *entryBlock = llvm::BasicBlock::Create(context, "entry", mainFunction);
    irBuilder.SetInsertPoint(entryBlock);

    // The format string is loop-invariant: emit one global and reuse it for every
    // call instead of creating a duplicate private global per expression.
    llvm::Value *formatStr = irBuilder.CreateGlobalStringPtr("expr value: %d\n");
    for (const auto &expr : p->ExprVec)
    {
        llvm::Value *v = expr->Accept(this);
        // Codegen yields nullptr for unsupported constructs; passing a null
        // argument to CreateCall would crash, so nothing is printed for it.
        if (v == nullptr)
        {
            continue;
        }
        irBuilder.CreateCall(printFunction, {formatStr, v});
    }
    // main always returns 0.
    llvm::Value *ret = irBuilder.CreateRet(irBuilder.getInt32(0));

    // verifyFunction returns true when the function is malformed; surface the
    // diagnostics instead of silently discarding the result.
    if (llvm::verifyFunction(*mainFunction, &llvm::errs()))
    {
        llvm::errs() << "error: generated function 'main' failed verification\n";
    }

    module->print(llvm::outs(), nullptr);
    return ret;
}

/// Lowers a binary arithmetic expression to a signed integer IR operation.
///
/// The left operand is generated before the right one. Constant operands may be
/// folded by IRBuilder into a constant result instead of an instruction.
///
/// @param binaryExpr  node holding `left`, `right` and the operator `op`.
/// @return the resulting value, or nullptr when an operand could not be
///         generated or the operator is not supported.
llvm::Value *CodeGen::VisitBinaryExpr(BinaryExpr *binaryExpr)
{
    auto left = binaryExpr->left->Accept(this);
    auto right = binaryExpr->right->Accept(this);
    // Propagate failure from a sub-expression instead of handing a null operand
    // to IRBuilder, which would dereference it.
    if (left == nullptr || right == nullptr)
    {
        return nullptr;
    }

    switch (binaryExpr->op)
    {
    case OpCode::add:
        // NSW ("no signed wrap") does not prevent overflow: it tells LLVM that
        // signed overflow does not happen, and an overflowing result is poison.
        return irBuilder.CreateNSWAdd(left, right, "add");
    case OpCode::sub:
        return irBuilder.CreateNSWSub(left, right, "sub");
    case OpCode::mul:
        return irBuilder.CreateNSWMul(left, right, "mul");
    case OpCode::div:
        // NOTE(review): `sdiv` by zero, or INT_MIN / -1, is undefined behavior in
        // LLVM IR; no runtime check is emitted here.
        return irBuilder.CreateSDiv(left, right, "div");
    default:
        break;
    }
    return nullptr;
}

/// Lowers a numeric literal to an `i32` constant (no instruction is emitted).
///
/// @param factorExpr  literal node; its `number` is passed to `getInt32`, i.e.
///                    converted to a 32-bit integer constant.
/// @return the constant integer value.
llvm::Value *CodeGen::VisitFactorExpr(FactorExpr *factorExpr)
{
    llvm::ConstantInt *literal = irBuilder.getInt32(factorExpr->number);
    return literal;
}